Employees are the lifeblood of any organization, directly impacting its performance and success. Understanding the dynamics of employee behavior, particularly attrition, is crucial for maintaining a stable and productive workforce. Employee attrition poses significant challenges, including the high cost of training new hires, loss of institutional knowledge, reduced productivity, and potential profit decline.
In this notebook, we explore employee data to answer key business questions: what drives attrition, and how can it be reduced?
Let's dive into the data to uncover insights and strategies for reducing attrition and enhancing organizational performance.
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import hvplot.pandas  # registers the .hvplot accessor on pandas objects
%matplotlib inline
sns.set_style('whitegrid')
plt.style.use('fivethirtyeight')
pd.set_option('display.float_format','{:.2f}'.format)
pd.set_option('display.max_columns',80)
pd.set_option('display.max_rows',80)
df = pd.read_csv('/Users/pritish/Documents/Project/data/HR DataAnalytics/WA_Fn-UseC_-HR-Employee-Attrition 2.csv')
df.head()
| | Age | Attrition | BusinessTravel | DailyRate | Department | DistanceFromHome | Education | EducationField | EmployeeCount | EmployeeNumber | EnvironmentSatisfaction | Gender | HourlyRate | JobInvolvement | JobLevel | JobRole | JobSatisfaction | MaritalStatus | MonthlyIncome | MonthlyRate | NumCompaniesWorked | Over18 | OverTime | PercentSalaryHike | PerformanceRating | RelationshipSatisfaction | StandardHours | StockOptionLevel | TotalWorkingYears | TrainingTimesLastYear | WorkLifeBalance | YearsAtCompany | YearsInCurrentRole | YearsSinceLastPromotion | YearsWithCurrManager |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 41 | Yes | Travel_Rarely | 1102 | Sales | 1 | 2 | Life Sciences | 1 | 1 | 2 | Female | 94 | 3 | 2 | Sales Executive | 4 | Single | 5993 | 19479 | 8 | Y | Yes | 11 | 3 | 1 | 80 | 0 | 8 | 0 | 1 | 6 | 4 | 0 | 5 |
| 1 | 49 | No | Travel_Frequently | 279 | Research & Development | 8 | 1 | Life Sciences | 1 | 2 | 3 | Male | 61 | 2 | 2 | Research Scientist | 2 | Married | 5130 | 24907 | 1 | Y | No | 23 | 4 | 4 | 80 | 1 | 10 | 3 | 3 | 10 | 7 | 1 | 7 |
| 2 | 37 | Yes | Travel_Rarely | 1373 | Research & Development | 2 | 2 | Other | 1 | 4 | 4 | Male | 92 | 2 | 1 | Laboratory Technician | 3 | Single | 2090 | 2396 | 6 | Y | Yes | 15 | 3 | 2 | 80 | 0 | 7 | 3 | 3 | 0 | 0 | 0 | 0 |
| 3 | 33 | No | Travel_Frequently | 1392 | Research & Development | 3 | 4 | Life Sciences | 1 | 5 | 4 | Female | 56 | 3 | 1 | Research Scientist | 3 | Married | 2909 | 23159 | 1 | Y | Yes | 11 | 3 | 3 | 80 | 0 | 8 | 3 | 3 | 8 | 7 | 3 | 0 |
| 4 | 27 | No | Travel_Rarely | 591 | Research & Development | 2 | 1 | Medical | 1 | 7 | 1 | Male | 40 | 3 | 1 | Laboratory Technician | 2 | Married | 3468 | 16632 | 9 | Y | No | 12 | 3 | 4 | 80 | 1 | 6 | 3 | 3 | 2 | 2 | 2 | 2 |
Discovering patterns in the data is essential for understanding the underlying trends and relationships. Through data visualization, we can uncover hidden insights and make informed decisions.
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1470 entries, 0 to 1469
Data columns (total 35 columns):
 #   Column                    Non-Null Count  Dtype
---  ------                    --------------  -----
 0   Age                       1470 non-null   int64
 1   Attrition                 1470 non-null   object
 2   BusinessTravel            1470 non-null   object
 3   DailyRate                 1470 non-null   int64
 4   Department                1470 non-null   object
 5   DistanceFromHome          1470 non-null   int64
 6   Education                 1470 non-null   int64
 7   EducationField            1470 non-null   object
 8   EmployeeCount             1470 non-null   int64
 9   EmployeeNumber            1470 non-null   int64
 10  EnvironmentSatisfaction   1470 non-null   int64
 11  Gender                    1470 non-null   object
 12  HourlyRate                1470 non-null   int64
 13  JobInvolvement            1470 non-null   int64
 14  JobLevel                  1470 non-null   int64
 15  JobRole                   1470 non-null   object
 16  JobSatisfaction           1470 non-null   int64
 17  MaritalStatus             1470 non-null   object
 18  MonthlyIncome             1470 non-null   int64
 19  MonthlyRate               1470 non-null   int64
 20  NumCompaniesWorked        1470 non-null   int64
 21  Over18                    1470 non-null   object
 22  OverTime                  1470 non-null   object
 23  PercentSalaryHike         1470 non-null   int64
 24  PerformanceRating         1470 non-null   int64
 25  RelationshipSatisfaction  1470 non-null   int64
 26  StandardHours             1470 non-null   int64
 27  StockOptionLevel          1470 non-null   int64
 28  TotalWorkingYears         1470 non-null   int64
 29  TrainingTimesLastYear     1470 non-null   int64
 30  WorkLifeBalance           1470 non-null   int64
 31  YearsAtCompany            1470 non-null   int64
 32  YearsInCurrentRole        1470 non-null   int64
 33  YearsSinceLastPromotion   1470 non-null   int64
 34  YearsWithCurrManager      1470 non-null   int64
dtypes: int64(26), object(9)
memory usage: 402.1+ KB
df.describe()
| | Age | DailyRate | DistanceFromHome | Education | EmployeeCount | EmployeeNumber | EnvironmentSatisfaction | HourlyRate | JobInvolvement | JobLevel | JobSatisfaction | MonthlyIncome | MonthlyRate | NumCompaniesWorked | PercentSalaryHike | PerformanceRating | RelationshipSatisfaction | StandardHours | StockOptionLevel | TotalWorkingYears | TrainingTimesLastYear | WorkLifeBalance | YearsAtCompany | YearsInCurrentRole | YearsSinceLastPromotion | YearsWithCurrManager |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 1470.00 | 1470.00 | 1470.00 | 1470.00 | 1470.00 | 1470.00 | 1470.00 | 1470.00 | 1470.00 | 1470.00 | 1470.00 | 1470.00 | 1470.00 | 1470.00 | 1470.00 | 1470.00 | 1470.00 | 1470.00 | 1470.00 | 1470.00 | 1470.00 | 1470.00 | 1470.00 | 1470.00 | 1470.00 | 1470.00 |
| mean | 36.92 | 802.49 | 9.19 | 2.91 | 1.00 | 1024.87 | 2.72 | 65.89 | 2.73 | 2.06 | 2.73 | 6502.93 | 14313.10 | 2.69 | 15.21 | 3.15 | 2.71 | 80.00 | 0.79 | 11.28 | 2.80 | 2.76 | 7.01 | 4.23 | 2.19 | 4.12 |
| std | 9.14 | 403.51 | 8.11 | 1.02 | 0.00 | 602.02 | 1.09 | 20.33 | 0.71 | 1.11 | 1.10 | 4707.96 | 7117.79 | 2.50 | 3.66 | 0.36 | 1.08 | 0.00 | 0.85 | 7.78 | 1.29 | 0.71 | 6.13 | 3.62 | 3.22 | 3.57 |
| min | 18.00 | 102.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 30.00 | 1.00 | 1.00 | 1.00 | 1009.00 | 2094.00 | 0.00 | 11.00 | 3.00 | 1.00 | 80.00 | 0.00 | 0.00 | 0.00 | 1.00 | 0.00 | 0.00 | 0.00 | 0.00 |
| 25% | 30.00 | 465.00 | 2.00 | 2.00 | 1.00 | 491.25 | 2.00 | 48.00 | 2.00 | 1.00 | 2.00 | 2911.00 | 8047.00 | 1.00 | 12.00 | 3.00 | 2.00 | 80.00 | 0.00 | 6.00 | 2.00 | 2.00 | 3.00 | 2.00 | 0.00 | 2.00 |
| 50% | 36.00 | 802.00 | 7.00 | 3.00 | 1.00 | 1020.50 | 3.00 | 66.00 | 3.00 | 2.00 | 3.00 | 4919.00 | 14235.50 | 2.00 | 14.00 | 3.00 | 3.00 | 80.00 | 1.00 | 10.00 | 3.00 | 3.00 | 5.00 | 3.00 | 1.00 | 3.00 |
| 75% | 43.00 | 1157.00 | 14.00 | 4.00 | 1.00 | 1555.75 | 4.00 | 83.75 | 3.00 | 3.00 | 4.00 | 8379.00 | 20461.50 | 4.00 | 18.00 | 3.00 | 4.00 | 80.00 | 1.00 | 15.00 | 3.00 | 3.00 | 9.00 | 7.00 | 3.00 | 7.00 |
| max | 60.00 | 1499.00 | 29.00 | 5.00 | 1.00 | 2068.00 | 4.00 | 100.00 | 4.00 | 5.00 | 4.00 | 19999.00 | 26999.00 | 9.00 | 25.00 | 4.00 | 4.00 | 80.00 | 3.00 | 40.00 | 6.00 | 4.00 | 40.00 | 18.00 | 15.00 | 17.00 |
# print(f'Number of unique values \n')
# for column in df.columns :
# print(f'{column}: {df[column].nunique()} unique values')
# print('================================================')
for column in df.columns:
    print(f'{column}:')
    print(f' - {df[column].nunique()} unique values')
    print(f' - Data type: {df[column].dtype}')
    print('================================================')
Age: - 43 unique values - Data type: int64
Attrition: - 2 unique values - Data type: object
BusinessTravel: - 3 unique values - Data type: object
DailyRate: - 886 unique values - Data type: int64
Department: - 3 unique values - Data type: object
DistanceFromHome: - 29 unique values - Data type: int64
Education: - 5 unique values - Data type: int64
EducationField: - 6 unique values - Data type: object
EmployeeCount: - 1 unique values - Data type: int64
EmployeeNumber: - 1470 unique values - Data type: int64
EnvironmentSatisfaction: - 4 unique values - Data type: int64
Gender: - 2 unique values - Data type: object
HourlyRate: - 71 unique values - Data type: int64
JobInvolvement: - 4 unique values - Data type: int64
JobLevel: - 5 unique values - Data type: int64
JobRole: - 9 unique values - Data type: object
JobSatisfaction: - 4 unique values - Data type: int64
MaritalStatus: - 3 unique values - Data type: object
MonthlyIncome: - 1349 unique values - Data type: int64
MonthlyRate: - 1427 unique values - Data type: int64
NumCompaniesWorked: - 10 unique values - Data type: int64
Over18: - 1 unique values - Data type: object
OverTime: - 2 unique values - Data type: object
PercentSalaryHike: - 15 unique values - Data type: int64
PerformanceRating: - 2 unique values - Data type: int64
RelationshipSatisfaction: - 4 unique values - Data type: int64
StandardHours: - 1 unique values - Data type: int64
StockOptionLevel: - 4 unique values - Data type: int64
TotalWorkingYears: - 40 unique values - Data type: int64
TrainingTimesLastYear: - 7 unique values - Data type: int64
WorkLifeBalance: - 4 unique values - Data type: int64
YearsAtCompany: - 37 unique values - Data type: int64
YearsInCurrentRole: - 19 unique values - Data type: int64
YearsSinceLastPromotion: - 16 unique values - Data type: int64
YearsWithCurrManager: - 18 unique values - Data type: int64
## We notice that 'EmployeeCount', 'Over18', and 'StandardHours' each have only one unique value, while 'EmployeeNumber' has 1470 unique values (one per row).
## These features carry no useful signal, so we drop those columns.
df.drop(['EmployeeCount', 'EmployeeNumber', 'Over18', 'StandardHours'], axis="columns", inplace=True)
# inplace=True modifies the original DataFrame instead of returning a new one.
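Rather than spotting these columns by eye, uninformative columns can also be detected programmatically. A minimal sketch on a toy DataFrame (the helper name `constant_and_id_columns` is ours, not part of any library):

```python
import pandas as pd

def constant_and_id_columns(frame):
    """Columns with a single unique value, or a distinct value per row (IDs)."""
    return [column for column in frame.columns
            if frame[column].nunique() in (1, len(frame))]

demo = pd.DataFrame({
    'EmployeeCount': [1, 1, 1],    # constant -> no signal
    'EmployeeNumber': [1, 2, 4],   # unique per row -> just an identifier
    'Age': [41, 49, 41],           # informative
})
print(constant_and_id_columns(demo))  # ['EmployeeCount', 'EmployeeNumber']
```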
object_col = []
for column in df.columns:
    if df[column].dtype == object and len(df[column].unique()) <= 30:
        object_col.append(column)
        print(f'{column} : {df[column].unique()}')
        print(df[column].value_counts())
        print('===================================')
object_col.remove('Attrition')
Attrition : ['Yes' 'No']
  No     1233
  Yes     237
BusinessTravel : ['Travel_Rarely' 'Travel_Frequently' 'Non-Travel']
  Travel_Rarely        1043
  Travel_Frequently     277
  Non-Travel            150
Department : ['Sales' 'Research & Development' 'Human Resources']
  Research & Development    961
  Sales                     446
  Human Resources            63
EducationField : ['Life Sciences' 'Other' 'Medical' 'Marketing' 'Technical Degree' 'Human Resources']
  Life Sciences       606
  Medical             464
  Marketing           159
  Technical Degree    132
  Other                82
  Human Resources      27
Gender : ['Female' 'Male']
  Male      882
  Female    588
JobRole : ['Sales Executive' 'Research Scientist' 'Laboratory Technician' 'Manufacturing Director' 'Healthcare Representative' 'Manager' 'Sales Representative' 'Research Director' 'Human Resources']
  Sales Executive              326
  Research Scientist           292
  Laboratory Technician        259
  Manufacturing Director       145
  Healthcare Representative    131
  Manager                      102
  Sales Representative          83
  Research Director             80
  Human Resources               52
MaritalStatus : ['Single' 'Married' 'Divorced']
  Married     673
  Single      470
  Divorced    327
OverTime : ['Yes' 'No']
  No     1054
  Yes     416
len(object_col)
7
from sklearn.preprocessing import LabelEncoder
label = LabelEncoder()
df['Attrition'] = label.fit_transform(df.Attrition)
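Note that `LabelEncoder` assigns integer codes in sorted order of the class labels, so here 'No' maps to 0 and 'Yes' maps to 1. A quick sketch confirming the mapping on toy values:

```python
from sklearn.preprocessing import LabelEncoder

enc = LabelEncoder()
codes = enc.fit_transform(['Yes', 'No', 'No', 'Yes'])
print(list(enc.classes_))  # ['No', 'Yes'] -- classes are sorted
print(codes.tolist())      # [1, 0, 0, 1] -- so 'No' -> 0 and 'Yes' -> 1
```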
df
| | Age | Attrition | BusinessTravel | DailyRate | Department | DistanceFromHome | Education | EducationField | EnvironmentSatisfaction | Gender | HourlyRate | JobInvolvement | JobLevel | JobRole | JobSatisfaction | MaritalStatus | MonthlyIncome | MonthlyRate | NumCompaniesWorked | OverTime | PercentSalaryHike | PerformanceRating | RelationshipSatisfaction | StockOptionLevel | TotalWorkingYears | TrainingTimesLastYear | WorkLifeBalance | YearsAtCompany | YearsInCurrentRole | YearsSinceLastPromotion | YearsWithCurrManager |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 41 | 1 | Travel_Rarely | 1102 | Sales | 1 | 2 | Life Sciences | 2 | Female | 94 | 3 | 2 | Sales Executive | 4 | Single | 5993 | 19479 | 8 | Yes | 11 | 3 | 1 | 0 | 8 | 0 | 1 | 6 | 4 | 0 | 5 |
| 1 | 49 | 0 | Travel_Frequently | 279 | Research & Development | 8 | 1 | Life Sciences | 3 | Male | 61 | 2 | 2 | Research Scientist | 2 | Married | 5130 | 24907 | 1 | No | 23 | 4 | 4 | 1 | 10 | 3 | 3 | 10 | 7 | 1 | 7 |
| 2 | 37 | 1 | Travel_Rarely | 1373 | Research & Development | 2 | 2 | Other | 4 | Male | 92 | 2 | 1 | Laboratory Technician | 3 | Single | 2090 | 2396 | 6 | Yes | 15 | 3 | 2 | 0 | 7 | 3 | 3 | 0 | 0 | 0 | 0 |
| 3 | 33 | 0 | Travel_Frequently | 1392 | Research & Development | 3 | 4 | Life Sciences | 4 | Female | 56 | 3 | 1 | Research Scientist | 3 | Married | 2909 | 23159 | 1 | Yes | 11 | 3 | 3 | 0 | 8 | 3 | 3 | 8 | 7 | 3 | 0 |
| 4 | 27 | 0 | Travel_Rarely | 591 | Research & Development | 2 | 1 | Medical | 1 | Male | 40 | 3 | 1 | Laboratory Technician | 2 | Married | 3468 | 16632 | 9 | No | 12 | 3 | 4 | 1 | 6 | 3 | 3 | 2 | 2 | 2 | 2 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 1465 | 36 | 0 | Travel_Frequently | 884 | Research & Development | 23 | 2 | Medical | 3 | Male | 41 | 4 | 2 | Laboratory Technician | 4 | Married | 2571 | 12290 | 4 | No | 17 | 3 | 3 | 1 | 17 | 3 | 3 | 5 | 2 | 0 | 3 |
| 1466 | 39 | 0 | Travel_Rarely | 613 | Research & Development | 6 | 1 | Medical | 4 | Male | 42 | 2 | 3 | Healthcare Representative | 1 | Married | 9991 | 21457 | 4 | No | 15 | 3 | 1 | 1 | 9 | 5 | 3 | 7 | 7 | 1 | 7 |
| 1467 | 27 | 0 | Travel_Rarely | 155 | Research & Development | 4 | 3 | Life Sciences | 2 | Male | 87 | 4 | 2 | Manufacturing Director | 2 | Married | 6142 | 5174 | 1 | Yes | 20 | 4 | 2 | 1 | 6 | 0 | 3 | 6 | 2 | 0 | 3 |
| 1468 | 49 | 0 | Travel_Frequently | 1023 | Sales | 2 | 3 | Medical | 4 | Male | 63 | 2 | 2 | Sales Executive | 2 | Married | 5390 | 13243 | 2 | No | 14 | 3 | 4 | 0 | 17 | 3 | 2 | 9 | 6 | 0 | 8 |
| 1469 | 34 | 0 | Travel_Rarely | 628 | Research & Development | 8 | 3 | Medical | 2 | Male | 82 | 4 | 2 | Laboratory Technician | 3 | Married | 4404 | 10228 | 2 | No | 12 | 3 | 1 | 0 | 6 | 3 | 4 | 4 | 3 | 1 | 2 |
1470 rows × 31 columns
disc_col = []
for column in df.columns:
    if df[column].dtypes != object and df[column].nunique() < 30:
        print(f"{column} : {df[column].unique()}")
        disc_col.append(column)
        print("====================================")
disc_col.remove('Attrition')
Attrition : [1 0]
DistanceFromHome : [ 1 8 2 3 24 23 27 16 15 26 19 21 5 11 9 7 6 10 4 25 12 18 29 22 14 20 28 17 13]
Education : [2 1 4 3 5]
EnvironmentSatisfaction : [2 3 4 1]
JobInvolvement : [3 2 4 1]
JobLevel : [2 1 3 4 5]
JobSatisfaction : [4 2 3 1]
NumCompaniesWorked : [8 1 6 9 0 4 5 2 7 3]
PercentSalaryHike : [11 23 15 12 13 20 22 21 17 14 16 18 19 24 25]
PerformanceRating : [3 4]
RelationshipSatisfaction : [1 4 2 3]
StockOptionLevel : [0 1 3 2]
TrainingTimesLastYear : [0 3 2 5 1 4 6]
WorkLifeBalance : [1 3 2 4]
YearsInCurrentRole : [ 4 7 0 2 5 9 8 3 6 13 1 15 14 16 11 10 12 18 17]
YearsSinceLastPromotion : [ 0 1 3 2 7 4 8 6 5 15 9 13 12 10 11 14]
YearsWithCurrManager : [ 5 7 0 2 6 8 3 11 17 1 4 12 9 10 15 13 16 14]
cont_col = []
for column in df.columns:
    if df[column].dtypes != object and df[column].nunique() > 30:
        print(f"{column} : Minimum: {df[column].min()}, Maximum: {df[column].max()}")
        cont_col.append(column)
        print("====================================")
Age : Minimum: 18, Maximum: 60
DailyRate : Minimum: 102, Maximum: 1499
HourlyRate : Minimum: 30, Maximum: 100
MonthlyIncome : Minimum: 1009, Maximum: 19999
MonthlyRate : Minimum: 2094, Maximum: 26999
TotalWorkingYears : Minimum: 0, Maximum: 40
YearsAtCompany : Minimum: 0, Maximum: 40
# distance from home vs attrition
df.hvplot.hist(y='DistanceFromHome', by='Attrition', subplots=False, width=600, height=300, bins=30, color=['#c8e371', '#9571e3'])
# plot of education vs attrition
df.hvplot.hist(y='Education', by='Attrition', subplots=False, width=600, height=300, color=['#c8e371', '#9571e3'])
# relationship satisfaction vs attrition
df.hvplot.hist(y = 'RelationshipSatisfaction', by = 'Attrition', subplots = False, width = 600, height = 300, color=['#c8e371', '#9571e3'])
# environmental satisfaction vs attrition
df.hvplot.hist(y='EnvironmentSatisfaction', by='Attrition', subplots=False, width=600, height=300, color=['#c8e371', '#9571e3'])
# job involvement vs attrition
df.hvplot.hist( y = 'JobInvolvement', by = "Attrition", subplots = False, width = 600, height = 300, color=['#c8e371', '#9571e3'])
# job level vs attrition
df.hvplot.hist(y='JobLevel', by='Attrition', subplots=False, height=300, width=600, color=['#c8e371', '#9571e3'])
df.hvplot.hist(y='JobSatisfaction', by='Attrition', subplots=False, height=300, width=600, color=['#c8e371', '#9571e3'])
df.hvplot.hist(y = 'NumCompaniesWorked', by = 'Attrition', subplots = False, width = 600, height = 300, color=['#c8e371', '#9571e3'])
df.hvplot.hist( y = 'PercentSalaryHike', by = 'Attrition' , subplots = False, width = 600, height = 300, color=['#c8e371', '#9571e3'])
df.hvplot.hist( y = 'StockOptionLevel', by = 'Attrition' , subplots = False, height = 300, width = 600, color=['#c8e371', '#9571e3'])
df.hvplot.hist( y = 'TrainingTimesLastYear' , by = 'Attrition', subplots = False, height = 300, width = 600, color=['#c8e371', '#9571e3'])
df.hvplot.hist(y='PerformanceRating', by='Attrition', subplots=False, height=300, width=600, color=['#c8e371', '#9571e3'])
df.hvplot.hist( y = 'Age', by = 'Attrition', subplots = False, height = 300 , width = 600, bins = 35, color=['#c8e371', '#9571e3'])
df.hvplot.hist( y = 'MonthlyIncome', by = 'Attrition', subplots = False, width = 600, height = 300, color=['#c8e371', '#9571e3'])
df.hvplot.hist( y = 'YearsAtCompany', by = 'Attrition', subplots = False, height = 300, width = 600, color=['#c8e371', '#9571e3'])
df.hvplot.hist( y = 'TotalWorkingYears', by = 'Attrition', subplots = False, height = 300, width = 600, color=['#c8e371', '#9571e3'])
df.hvplot.hist( y = 'YearsWithCurrManager', by = 'Attrition', subplots = False, height = 300, width = 600, color=['#c8e371', '#9571e3'])
# df.hvplot.hist( y = 'Gender', by = 'Attrition', subplots = False, height = 300, width = 600)
# Group by 'Gender' and 'Attrition' to get counts
gender_attrition = df.groupby('Attrition')['Gender'].value_counts().unstack()
gender_attrition.hvplot.bar(stacked=True, height=300, width=600, color=['#c8e371', '#9571e3'])
📌 Note: EnvironmentSatisfaction, JobSatisfaction, PerformanceRating, and RelationshipSatisfaction don't appear to have a big impact on whether an employee leaves.
Workers with a low JobLevel, MonthlyIncome, YearsAtCompany, or TotalWorkingYears are more likely to quit their jobs.
BusinessTravel : workers who travel a lot are more likely to quit than other employees.
Department : workers in Research & Development are more likely to stay than workers in other departments.
EducationField : workers with a Human Resources or Technical Degree background are more likely to quit than employees from other fields of education.
Gender : male employees are more likely to quit.
JobRole : workers in the Laboratory Technician, Sales Representative, and Human Resources roles are more likely to quit than workers in other positions.
MaritalStatus : single workers are more likely to quit than married or divorced ones.
OverTime : workers who work overtime are more likely to quit than others.
# Correlation Matrix (numeric_only=True skips the remaining object columns; required on pandas >= 2.0)
plt.figure(figsize=(30, 30))
sns.heatmap(df.corr(numeric_only=True), annot=True, cmap="RdYlGn", annot_kws={"size": 15})
<AxesSubplot:>
col = df.corr(numeric_only=True).nlargest(20, "Attrition").Attrition.index
plt.figure(figsize=(15, 15))
sns.heatmap(df[col].corr(), annot=True, cmap="RdYlGn", annot_kws={"size": 10})
<AxesSubplot:>
# df.drop('Attrition', axis=1).corrwith(df.Attrition).hvplot.barh()
df.drop('Attrition', axis=1).corrwith(df.Attrition, numeric_only=True).hvplot.barh(color='#c8e371')
# Transform categorical data into dummies
dummy_col = [column for column in df.drop('Attrition', axis=1).columns if df[column].nunique() < 20]
data = pd.get_dummies(df, columns=dummy_col, drop_first=True, dtype='uint8')
data.info()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 1470 entries, 0 to 1469 Columns: 136 entries, Age to YearsWithCurrManager_17 dtypes: int64(9), uint8(127) memory usage: 285.8 KB
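To see what `drop_first=True` does, here is a minimal sketch on a hypothetical toy column: the first category becomes the implicit baseline and only the remaining categories get dummy columns, which avoids perfectly collinear features.

```python
import pandas as pd

toy = pd.DataFrame({'Gender': ['Female', 'Male', 'Male']})  # toy data
dummies = pd.get_dummies(toy, columns=['Gender'], drop_first=True, dtype='uint8')
print(list(dummies.columns))            # ['Gender_Male'] -- 'Female' is the baseline
print(dummies['Gender_Male'].tolist())  # [0, 1, 1]
```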
print(data.shape)
#Remove Duplicate Features
data = data.T.drop_duplicates()
data = data.T
#Remove Duplicate Rows
data.drop_duplicates(inplace = True)
print(data.shape)
(1470, 136) (1470, 136)
data.shape
(1470, 136)
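The transpose trick above works because `drop_duplicates` operates on rows: transposing turns columns into rows, duplicate columns get dropped, and transposing back restores the original orientation. A minimal sketch on a toy frame:

```python
import pandas as pd

# 'b' is an exact duplicate of column 'a'
frame = pd.DataFrame({'a': [1, 0, 1], 'b': [1, 0, 1], 'c': [0, 1, 0]})

deduped = frame.T.drop_duplicates().T  # columns -> rows, dedupe, back again
print(list(deduped.columns))           # ['a', 'c']
```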
# data.drop('Attrition', axis=1).corrwith(data.Attrition).sort_values().plot(kind='barh', figsize=(10, 30))
data.drop('Attrition', axis=1).corrwith(data.Attrition).sort_values().plot(kind='barh', figsize=(10, 30), color='#c8e371')
<AxesSubplot:>
feature_correlation = data.drop('Attrition', axis=1).corrwith(data.Attrition).sort_values()
model_col = feature_correlation[np.abs(feature_correlation) > 0.02].index
len(model_col)
92
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import StandardScaler
X = data.drop('Attrition', axis = 1)
y = data.Attrition
X_train, X_test, y_train, y_test = train_test_split(X,y, test_size = 0.3, random_state = 42, stratify = y)
scaler = StandardScaler()
X_train_std = scaler.fit_transform(X_train)
X_test_std = scaler.transform(X_test)
X_std = scaler.transform(X)
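Fitting the scaler on the training split only, and merely transforming the test split, keeps test-set statistics out of the model (no data leakage). A small sketch of this pattern on toy numbers:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

train = np.array([[1.0], [2.0], [3.0], [4.0]])  # toy training feature
test = np.array([[2.5]])                        # toy test feature

scaler = StandardScaler()
train_std = scaler.fit_transform(train)  # statistics learned from train only
test_std = scaler.transform(test)        # test reuses the train mean/std

print(train_std.mean(), train_std.std())  # ~0 and ~1 on the training split
print(test_std[0, 0])                     # ~0, since 2.5 equals the train mean
```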
def feature_imp(df, model):
    fi = pd.DataFrame()
    fi["feature"] = df.columns
    fi["importance"] = model.feature_importances_
    return fi.sort_values(by="importance", ascending=False)
We have imbalanced data: if we simply predict that every employee will stay, we already achieve an accuracy of 83.90% on the test set.
y_test.value_counts()[0] / y_test.shape[0]
0.8390022675736961
stay = y_train.value_counts()[0] / y_train.shape[0]
leave = y_train.value_counts()[1] / y_train.shape[0]
print("===============TRAIN=================")
print(f"Staying Rate: {stay * 100:.2f}%")
print(f"Leaving Rate: {leave * 100:.2f}%")
stay = y_test.value_counts()[0] / y_test.shape[0]
leave = y_test.value_counts()[1] / y_test.shape[0]
print("===============TEST=================")
print(f"Staying Rate: {stay * 100:.2f}%")
print(f"Leaving Rate: {leave * 100:.2f}%")
===============TRAIN================= Staying Rate: 83.87% Leaving Rate: 16.13% ===============TEST================= Staying Rate: 83.90% Leaving Rate: 16.10%
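The near-identical rates in the two splits are the effect of `stratify=y` in `train_test_split`: the class proportions of `y` are preserved in both partitions. A minimal sketch with a toy 84/16 imbalance:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy target with roughly the same 84/16 imbalance as the Attrition column
y = np.array([0] * 84 + [1] * 16)
X = np.zeros((100, 1))

_, _, y_tr, y_te = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)
print(y_tr.mean(), y_te.mean())  # minority share stays close to 0.16 in both splits
```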
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report, roc_auc_score
def evaluate(model, X_train, X_test, y_train, y_test):
    y_test_pred = model.predict(X_test)
    y_train_pred = model.predict(X_train)

    print("TRAINING RESULTS: \n===============================")
    clf_report = pd.DataFrame(classification_report(y_train, y_train_pred, output_dict=True))
    print(f"CONFUSION MATRIX:\n{confusion_matrix(y_train, y_train_pred)}")
    print(f"ACCURACY SCORE:\n{accuracy_score(y_train, y_train_pred):.4f}")
    print(f"CLASSIFICATION REPORT:\n{clf_report}")

    print("TESTING RESULTS: \n===============================")
    clf_report = pd.DataFrame(classification_report(y_test, y_test_pred, output_dict=True))
    print(f"CONFUSION MATRIX:\n{confusion_matrix(y_test, y_test_pred)}")
    print(f"ACCURACY SCORE:\n{accuracy_score(y_test, y_test_pred):.4f}")
    print(f"CLASSIFICATION REPORT:\n{clf_report}")
from sklearn.linear_model import LogisticRegression
lr_clf = LogisticRegression(solver = 'liblinear', penalty = 'l1')
lr_clf.fit(X_train_std, y_train)
evaluate(lr_clf, X_train_std, X_test_std, y_train, y_test)
TRAINING RESULTS:
===============================
CONFUSION MATRIX:
[[849 14]
[ 59 107]]
ACCURACY SCORE:
0.9291
CLASSIFICATION REPORT:
0 1 accuracy macro avg weighted avg
precision 0.94 0.88 0.93 0.91 0.93
recall 0.98 0.64 0.93 0.81 0.93
f1-score 0.96 0.75 0.93 0.85 0.92
support 863.00 166.00 0.93 1029.00 1029.00
TESTING RESULTS:
===============================
CONFUSION MATRIX:
[[348 22]
[ 43 28]]
ACCURACY SCORE:
0.8526
CLASSIFICATION REPORT:
0 1 accuracy macro avg weighted avg
precision 0.89 0.56 0.85 0.73 0.84
recall 0.94 0.39 0.85 0.67 0.85
f1-score 0.91 0.46 0.85 0.69 0.84
support 370.00 71.00 0.85 441.00 441.00
from sklearn.metrics import precision_recall_curve, roc_curve
def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
    plt.plot(thresholds, precisions[:-1], "b--", label="Precision")
    plt.plot(thresholds, recalls[:-1], "g--", label="Recall")
    plt.xlabel("Threshold")
    plt.legend(loc="upper left")
    plt.title("Precision/Recall Tradeoff")

def plot_roc_curve(fpr, tpr, label=None):
    plt.plot(fpr, tpr, linewidth=2, label=label)
    plt.plot([0, 1], [0, 1], "k--")
    plt.axis([0, 1, 0, 1])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve')
# Use predicted probabilities rather than hard 0/1 labels, so the curve sweeps over all thresholds
y_scores = lr_clf.predict_proba(X_test_std)[:, 1]
precisions, recalls, thresholds = precision_recall_curve(y_test, y_scores)
plt.figure(figsize=(14, 25))
plt.subplot(4, 2, 1)
plot_precision_recall_vs_threshold(precisions, recalls, thresholds)
plt.subplot(4, 2, 2)
plt.plot(precisions, recalls)
plt.xlabel("Precision")
plt.ylabel("Recall")
plt.title("PR Curve: precisions/recalls tradeoff");
plt.subplot(4, 2, 3)
fpr, tpr, thresholds = roc_curve(y_test, lr_clf.predict_proba(X_test_std)[:, 1])
plot_roc_curve(fpr, tpr)
# The model was fitted on the scaled arrays, so score it on the scaled arrays as well
scores_dict = {
    'Logistic Regression': {
        'Train': roc_auc_score(y_train, lr_clf.predict(X_train_std)),
        'Test': roc_auc_score(y_test, lr_clf.predict(X_test_std)),
    },
}
# Random Forest Classifier
from sklearn.ensemble import RandomForestClassifier
rf_clf = RandomForestClassifier(n_estimators=100, bootstrap=False,
# class_weight={0:stay, 1:leave}
)
rf_clf.fit(X_train, y_train)
evaluate(rf_clf, X_train, X_test, y_train, y_test)
TRAINING RESULTS:
===============================
CONFUSION MATRIX:
[[863 0]
[ 0 166]]
ACCURACY SCORE:
1.0000
CLASSIFICATION REPORT:
0 1 accuracy macro avg weighted avg
precision 1.00 1.00 1.00 1.00 1.00
recall 1.00 1.00 1.00 1.00 1.00
f1-score 1.00 1.00 1.00 1.00 1.00
support 863.00 166.00 1.00 1029.00 1029.00
TESTING RESULTS:
===============================
CONFUSION MATRIX:
[[355 15]
[ 63 8]]
ACCURACY SCORE:
0.8231
CLASSIFICATION REPORT:
0 1 accuracy macro avg weighted avg
precision 0.85 0.35 0.82 0.60 0.77
recall 0.96 0.11 0.82 0.54 0.82
f1-score 0.90 0.17 0.82 0.54 0.78
support 370.00 71.00 0.82 441.00 441.00
param_grid = dict(
    n_estimators=[100, 500, 900],
    max_features=['sqrt', 'log2'],  # 'auto' is no longer accepted by recent scikit-learn
    max_depth=[2, 3, 5, 10, 15, None],
    min_samples_split=[2, 5, 10],
    min_samples_leaf=[1, 2, 4],
    bootstrap=[True, False]
)
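Before launching a grid search this size, it can help to sanity-check how many fits it implies. The grid below mirrors the one above (with `max_features` values accepted by recent scikit-learn):

```python
from math import prod

param_grid = dict(
    n_estimators=[100, 500, 900],
    max_features=['sqrt', 'log2'],
    max_depth=[2, 3, 5, 10, 15, None],
    min_samples_split=[2, 5, 10],
    min_samples_leaf=[1, 2, 4],
    bootstrap=[True, False],
)
candidates = prod(len(values) for values in param_grid.values())
print(candidates)      # 648 parameter combinations
print(candidates * 5)  # 3240 fits with cv=5
```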
rf_clf = RandomForestClassifier(random_state=42)
search = GridSearchCV(rf_clf, param_grid=param_grid, scoring='roc_auc', cv=5, verbose=1, n_jobs=-1)
search.fit(X_train, y_train)
rf_clf = RandomForestClassifier(**search.best_params_, random_state=42)
rf_clf.fit(X_train, y_train)
evaluate(rf_clf, X_train, X_test, y_train, y_test)
Fitting 5 folds for each of 648 candidates, totalling 3240 fits
TRAINING RESULTS:
===============================
CONFUSION MATRIX:
[[863 0]
[ 15 151]]
ACCURACY SCORE:
0.9854
CLASSIFICATION REPORT:
0 1 accuracy macro avg weighted avg
precision 0.98 1.00 0.99 0.99 0.99
recall 1.00 0.91 0.99 0.95 0.99
f1-score 0.99 0.95 0.99 0.97 0.99
support 863.00 166.00 0.99 1029.00 1029.00
TESTING RESULTS:
===============================
CONFUSION MATRIX:
[[360 10]
[ 63 8]]
ACCURACY SCORE:
0.8345
CLASSIFICATION REPORT:
0 1 accuracy macro avg weighted avg
precision 0.85 0.44 0.83 0.65 0.79
recall 0.97 0.11 0.83 0.54 0.83
f1-score 0.91 0.18 0.83 0.54 0.79
support 370.00 71.00 0.83 441.00 441.00
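Even after tuning, the tuned forest catches only 8 of the 71 leavers in the test set: the classes are heavily imbalanced (166 "Yes" vs 863 "No" in training). One common mitigation is reweighting the minority class with `class_weight='balanced'`. A minimal sketch below, on synthetic imbalanced data from `make_classification` (the notebook's `X_train`/`y_train` would be substituted directly); whether recall actually improves depends on the data.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import recall_score
from sklearn.model_selection import train_test_split

# Synthetic stand-in with roughly the same ~16% positive rate as the HR data
X, y = make_classification(n_samples=1500, weights=[0.84, 0.16], random_state=42)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, stratify=y, random_state=42)

plain = RandomForestClassifier(random_state=42).fit(X_tr, y_tr)
# class_weight='balanced' reweights samples inversely to class frequency
weighted = RandomForestClassifier(class_weight='balanced', random_state=42).fit(X_tr, y_tr)

print('plain minority recall:   ', recall_score(y_te, plain.predict(X_te)))
print('weighted minority recall:', recall_score(y_te, weighted.predict(X_te)))
```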
# PR curves need continuous scores, not hard 0/1 labels
precisions, recalls, thresholds = precision_recall_curve(y_test, rf_clf.predict_proba(X_test)[:, 1])
plt.figure(figsize=(14, 25))
plt.subplot(4, 2, 1)
plot_precision_recall_vs_threshold(precisions, recalls, thresholds)
plt.subplot(4, 2, 2)
plt.plot(precisions, recalls)
plt.xlabel("Precision")
plt.ylabel("Recall")
plt.title("PR Curve: precisions/recalls tradeoff");
plt.subplot(4, 2, 3)
fpr, tpr, thresholds = roc_curve(y_test, rf_clf.predict_proba(X_test)[:, 1])
plot_roc_curve(fpr, tpr)
scores_dict['Random Forest'] = {
'Train': roc_auc_score(y_train, rf_clf.predict(X_train)),
'Test': roc_auc_score(y_test, rf_clf.predict(X_test)),
}
df = feature_imp(X, rf_clf)[:40]
df.set_index('feature', inplace=True)
df.plot(kind='barh', figsize=(10, 10), color='#c8e371')
plt.title('Feature Importance according to Random Forest')
Text(0.5, 1.0, 'Feature Importance according to Random Forest')
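The exhaustive grid above fits 648 candidates × 5 folds = 3,240 models. `RandomizedSearchCV` samples a fixed number of parameter settings instead, which usually finds a comparable optimum at a fraction of the cost. A sketch on synthetic data (the distributions and `n_iter` here are illustrative choices, not tuned values):

```python
from scipy.stats import randint
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV

X, y = make_classification(n_samples=500, random_state=42)

# Distributions are sampled, so the search cost is n_iter * cv fits
param_dist = {
    'n_estimators': randint(100, 400),
    'max_depth': [2, 3, 5, 10, 15, None],
    'min_samples_split': randint(2, 11),
    'min_samples_leaf': randint(1, 5),
}
rand_search = RandomizedSearchCV(RandomForestClassifier(random_state=42),
                                 param_distributions=param_dist,
                                 n_iter=10, scoring='roc_auc', cv=3,
                                 random_state=42, n_jobs=-1)
rand_search.fit(X, y)
print(rand_search.best_params_)
```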
Support Vector Machine
from sklearn.svm import SVC
svm_clf = SVC(kernel='linear')
svm_clf.fit(X_train_std, y_train)
evaluate(svm_clf, X_train_std, X_test_std, y_train, y_test)
TRAINING RESULTS:
===============================
CONFUSION MATRIX:
[[855 8]
[ 47 119]]
ACCURACY SCORE:
0.9466
CLASSIFICATION REPORT:
0 1 accuracy macro avg weighted avg
precision 0.95 0.94 0.95 0.94 0.95
recall 0.99 0.72 0.95 0.85 0.95
f1-score 0.97 0.81 0.95 0.89 0.94
support 863.00 166.00 0.95 1029.00 1029.00
TESTING RESULTS:
===============================
CONFUSION MATRIX:
[[345 25]
[ 44 27]]
ACCURACY SCORE:
0.8435
CLASSIFICATION REPORT:
0 1 accuracy macro avg weighted avg
precision 0.89 0.52 0.84 0.70 0.83
recall 0.93 0.38 0.84 0.66 0.84
f1-score 0.91 0.44 0.84 0.67 0.83
support 370.00 71.00 0.84 441.00 441.00
svm_clf = SVC(random_state=42)
param_grid = [
{'C': [1, 10, 100, 1000], 'kernel': ['linear']},
{'C': [1, 10, 100, 1000], 'gamma': [0.001, 0.0001], 'kernel': ['rbf']}
]
search = GridSearchCV(svm_clf, param_grid=param_grid, scoring='roc_auc', cv=3, refit=True, verbose=1)
search.fit(X_train_std, y_train)
Fitting 3 folds for each of 12 candidates, totalling 36 fits
GridSearchCV(cv=3, estimator=SVC(random_state=42),
             param_grid=[{'C': [1, 10, 100, 1000], 'kernel': ['linear']},
                         {'C': [1, 10, 100, 1000], 'gamma': [0.001, 0.0001],
                          'kernel': ['rbf']}],
             scoring='roc_auc', verbose=1)
svm_clf = SVC(**search.best_params_, random_state=42)
svm_clf.fit(X_train_std, y_train)
evaluate(svm_clf, X_train_std, X_test_std, y_train, y_test)
TRAINING RESULTS:
===============================
CONFUSION MATRIX:
[[862 1]
[ 6 160]]
ACCURACY SCORE:
0.9932
CLASSIFICATION REPORT:
0 1 accuracy macro avg weighted avg
precision 0.99 0.99 0.99 0.99 0.99
recall 1.00 0.96 0.99 0.98 0.99
f1-score 1.00 0.98 0.99 0.99 0.99
support 863.00 166.00 0.99 1029.00 1029.00
TESTING RESULTS:
===============================
CONFUSION MATRIX:
[[346 24]
[ 42 29]]
ACCURACY SCORE:
0.8503
CLASSIFICATION REPORT:
0 1 accuracy macro avg weighted avg
precision 0.89 0.55 0.85 0.72 0.84
recall 0.94 0.41 0.85 0.67 0.85
f1-score 0.91 0.47 0.85 0.69 0.84
support 370.00 71.00 0.85 441.00 441.00
# SVC has no predict_proba by default; decision_function gives continuous margins
precisions, recalls, thresholds = precision_recall_curve(y_test, svm_clf.decision_function(X_test_std))
plt.figure(figsize=(14, 25))
plt.subplot(4, 2, 1)
plot_precision_recall_vs_threshold(precisions, recalls, thresholds)
plt.subplot(4, 2, 2)
plt.plot(precisions, recalls)
plt.xlabel("Precision")
plt.ylabel("Recall")
plt.title("PR Curve: precisions/recalls tradeoff");
plt.subplot(4, 2, 3)
fpr, tpr, thresholds = roc_curve(y_test, svm_clf.decision_function(X_test_std))
plot_roc_curve(fpr, tpr)
scores_dict['Support Vector Machine'] = {
'Train': roc_auc_score(y_train, svm_clf.predict(X_train_std)),
'Test': roc_auc_score(y_test, svm_clf.predict(X_test_std)),
}
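Note that `SVC` does not expose `predict_proba` unless it is constructed with `probability=True`; its `decision_function` returns signed distances to the separating hyperplane, which are continuous scores suitable for `roc_curve` and `roc_auc_score`. A self-contained sketch on synthetic data (scaling mirrors the `X_train_std` preprocessing used in the notebook):

```python
from sklearn.datasets import make_classification
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

X, y = make_classification(n_samples=600, random_state=42)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, random_state=42)
scaler = StandardScaler().fit(X_tr)

clf = SVC(random_state=42).fit(scaler.transform(X_tr), y_tr)
scores = clf.decision_function(scaler.transform(X_te))  # continuous margins, not labels
print('AUC from margins:', roc_auc_score(y_te, scores))
```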
XGBoost Classifier
from xgboost import XGBClassifier
xgb_clf = XGBClassifier()
xgb_clf.fit(X_train, y_train)
evaluate(xgb_clf, X_train, X_test, y_train, y_test)
TRAINING RESULTS:
===============================
CONFUSION MATRIX:
[[863 0]
[ 0 166]]
ACCURACY SCORE:
1.0000
CLASSIFICATION REPORT:
0 1 accuracy macro avg weighted avg
precision 1.00 1.00 1.00 1.00 1.00
recall 1.00 1.00 1.00 1.00 1.00
f1-score 1.00 1.00 1.00 1.00 1.00
support 863.00 166.00 1.00 1029.00 1029.00
TESTING RESULTS:
===============================
CONFUSION MATRIX:
[[356 14]
[ 51 20]]
ACCURACY SCORE:
0.8526
CLASSIFICATION REPORT:
0 1 accuracy macro avg weighted avg
precision 0.87 0.59 0.85 0.73 0.83
recall 0.96 0.28 0.85 0.62 0.85
f1-score 0.92 0.38 0.85 0.65 0.83
support 370.00 71.00 0.85 441.00 441.00
scores_dict['XGBoost'] = {
'Train': roc_auc_score(y_train, xgb_clf.predict(X_train)),
'Test': roc_auc_score(y_test, xgb_clf.predict(X_test)),
}
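XGBoost catches only 20 of the 71 leavers at the default 0.5 probability cutoff. When missing a leaver is costlier than a false alarm, lowering the decision threshold trades precision for recall. The sketch below uses a random forest on synthetic imbalanced data to keep it self-contained; the same thresholding applies to any classifier with `predict_proba`.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import recall_score
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=1500, weights=[0.84, 0.16], random_state=0)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, stratify=y, random_state=0)
clf = RandomForestClassifier(random_state=0).fit(X_tr, y_tr)

proba = clf.predict_proba(X_te)[:, 1]
for threshold in (0.5, 0.3, 0.2):
    # Lowering the threshold can only add positive predictions, so recall rises
    preds = (proba >= threshold).astype(int)
    print(threshold, 'recall:', recall_score(y_te, preds))
```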
precisions, recalls, thresholds = precision_recall_curve(y_test, xgb_clf.predict_proba(X_test)[:, 1])
plt.figure(figsize=(14, 25))
plt.subplot(4, 2, 1)
plot_precision_recall_vs_threshold(precisions, recalls, thresholds)
plt.subplot(4, 2, 2)
plt.plot(precisions, recalls)
plt.xlabel("Precision")
plt.ylabel("Recall")
plt.title("PR Curve: precisions/recalls tradeoff");
plt.subplot(4, 2, 3)
fpr, tpr, thresholds = roc_curve(y_test, xgb_clf.predict_proba(X_test)[:, 1])
plot_roc_curve(fpr, tpr)
df = feature_imp(X, xgb_clf)[:35]
df.set_index('feature', inplace=True)
df.plot(kind='barh', figsize=(10, 8))
plt.title('Feature Importance according to XGBoost')
Text(0.5, 1.0, 'Feature Importance according to XGBoost')
LightGBM
from lightgbm import LGBMClassifier
lgb_clf = LGBMClassifier()
lgb_clf.fit(X_train, y_train)
evaluate(lgb_clf, X_train, X_test, y_train, y_test)
[LightGBM] [Warning] Found whitespace in feature_names, replace with underlines
[LightGBM] [Info] Number of positive: 166, number of negative: 863
[LightGBM] [Info] Auto-choosing row-wise multi-threading, the overhead of testing was 0.000595 seconds.
You can set `force_row_wise=true` to remove the overhead.
And if memory is not enough, you can set `force_col_wise=true`.
[LightGBM] [Info] Total Bins 1176
[LightGBM] [Info] Number of data points in the train set: 1029, number of used features: 108
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.161322 -> initscore=-1.648427
[LightGBM] [Info] Start training from score -1.648427
TRAINING RESULTS:
===============================
CONFUSION MATRIX:
[[863 0]
[ 0 166]]
ACCURACY SCORE:
1.0000
CLASSIFICATION REPORT:
0 1 accuracy macro avg weighted avg
precision 1.00 1.00 1.00 1.00 1.00
recall 1.00 1.00 1.00 1.00 1.00
f1-score 1.00 1.00 1.00 1.00 1.00
support 863.00 166.00 1.00 1029.00 1029.00
TESTING RESULTS:
===============================
CONFUSION MATRIX:
[[357 13]
[ 53 18]]
ACCURACY SCORE:
0.8503
CLASSIFICATION REPORT:
0 1 accuracy macro avg weighted avg
precision 0.87 0.58 0.85 0.73 0.82
recall 0.96 0.25 0.85 0.61 0.85
f1-score 0.92 0.35 0.85 0.63 0.82
support 370.00 71.00 0.85 441.00 441.00
precisions, recalls, thresholds = precision_recall_curve(y_test, lgb_clf.predict_proba(X_test)[:, 1])
plt.figure(figsize=(14, 25))
plt.subplot(4, 2, 1)
plot_precision_recall_vs_threshold(precisions, recalls, thresholds)
plt.subplot(4, 2, 2)
plt.plot(precisions, recalls)
plt.xlabel("Precision")
plt.ylabel("Recall")
plt.title("PR Curve: precisions/recalls tradeoff");
plt.subplot(4, 2, 3)
fpr, tpr, thresholds = roc_curve(y_test, lgb_clf.predict_proba(X_test)[:, 1])
plot_roc_curve(fpr, tpr)
scores_dict['LightGBM'] = {
'Train': roc_auc_score(y_train, lgb_clf.predict(X_train)),
'Test': roc_auc_score(y_test, lgb_clf.predict(X_test)),
}
!pip install catboost
Collecting catboost
  Downloading catboost-1.2.5-cp310-cp310-manylinux2014_x86_64.whl (98.2 MB)
Installing collected packages: catboost
Successfully installed catboost-1.2.5
from catboost import CatBoostClassifier
cb_clf = CatBoostClassifier()
cb_clf.fit(X_train, y_train, verbose=0)
evaluate(cb_clf, X_train, X_test, y_train, y_test)
TRAINING RESULTS:
===============================
CONFUSION MATRIX:
[[863 0]
[ 16 150]]
ACCURACY SCORE:
0.9845
CLASSIFICATION REPORT:
0 1 accuracy macro avg weighted avg
precision 0.98 1.00 0.98 0.99 0.98
recall 1.00 0.90 0.98 0.95 0.98
f1-score 0.99 0.95 0.98 0.97 0.98
support 863.00 166.00 0.98 1029.00 1029.00
TESTING RESULTS:
===============================
CONFUSION MATRIX:
[[361 9]
[ 57 14]]
ACCURACY SCORE:
0.8503
CLASSIFICATION REPORT:
0 1 accuracy macro avg weighted avg
precision 0.86 0.61 0.85 0.74 0.82
recall 0.98 0.20 0.85 0.59 0.85
f1-score 0.92 0.30 0.85 0.61 0.82
support 370.00 71.00 0.85 441.00 441.00
precisions, recalls, thresholds = precision_recall_curve(y_test, cb_clf.predict_proba(X_test)[:, 1])
plt.figure(figsize=(14, 25))
plt.subplot(4, 2, 1)
plot_precision_recall_vs_threshold(precisions, recalls, thresholds)
plt.subplot(4, 2, 2)
plt.plot(precisions, recalls)
plt.xlabel("Precision")
plt.ylabel("Recall")
plt.title("PR Curve: precisions/recalls tradeoff");
plt.subplot(4, 2, 3)
fpr, tpr, thresholds = roc_curve(y_test, cb_clf.predict_proba(X_test)[:, 1])
plot_roc_curve(fpr, tpr)
scores_dict['CatBoost'] = {
'Train': roc_auc_score(y_train, cb_clf.predict(X_train)),
'Test': roc_auc_score(y_test, cb_clf.predict(X_test)),
}
AdaBoost
from sklearn.ensemble import AdaBoostClassifier
ab_clf = AdaBoostClassifier()
ab_clf.fit(X_train, y_train)
evaluate(ab_clf, X_train, X_test, y_train, y_test)
TRAINING RESULTS:
===============================
CONFUSION MATRIX:
[[843 20]
[ 88 78]]
ACCURACY SCORE:
0.8950
CLASSIFICATION REPORT:
0 1 accuracy macro avg weighted avg
precision 0.91 0.80 0.90 0.85 0.89
recall 0.98 0.47 0.90 0.72 0.90
f1-score 0.94 0.59 0.90 0.77 0.88
support 863.00 166.00 0.90 1029.00 1029.00
TESTING RESULTS:
===============================
CONFUSION MATRIX:
[[344 26]
[ 52 19]]
ACCURACY SCORE:
0.8231
CLASSIFICATION REPORT:
0 1 accuracy macro avg weighted avg
precision 0.87 0.42 0.82 0.65 0.80
recall 0.93 0.27 0.82 0.60 0.82
f1-score 0.90 0.33 0.82 0.61 0.81
support 370.00 71.00 0.82 441.00 441.00
precisions, recalls, thresholds = precision_recall_curve(y_test, ab_clf.predict_proba(X_test)[:, 1])
plt.figure(figsize=(14, 25))
plt.subplot(4, 2, 1)
plot_precision_recall_vs_threshold(precisions, recalls, thresholds)
plt.subplot(4, 2, 2)
plt.plot(precisions, recalls)
plt.xlabel("Precision")
plt.ylabel("Recall")
plt.title("PR Curve: precisions/recalls tradeoff");
plt.subplot(4, 2, 3)
fpr, tpr, thresholds = roc_curve(y_test, ab_clf.predict_proba(X_test)[:, 1])
plot_roc_curve(fpr, tpr)
scores_dict['AdaBoost'] = {
'Train': roc_auc_score(y_train, ab_clf.predict(X_train)),
'Test': roc_auc_score(y_test, ab_clf.predict(X_test)),
}
Comparing Model Performance
ml_models = {
'Random Forest': rf_clf,
'XGBoost': xgb_clf,
'Logistic Regression': lr_clf,
'Support Vector Machine': svm_clf,
'LightGBM': lgb_clf,
'CatBoost': cb_clf,
'AdaBoost': ab_clf
}
for model in ml_models:
print(f"{model.upper():{30}} roc_auc_score: {roc_auc_score(y_test, ml_models[model].predict(X_test)):.3f}")
RANDOM FOREST                  roc_auc_score: 0.543
XGBOOST                        roc_auc_score: 0.622
LOGISTIC REGRESSION            roc_auc_score: 0.546
SUPPORT VECTOR MACHINE         roc_auc_score: 0.500
LIGHTGBM                       roc_auc_score: 0.609
CATBOOST                       roc_auc_score: 0.586
ADABOOST                       roc_auc_score: 0.599
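A caveat on the comparison above: it feeds `roc_auc_score` the hard 0/1 labels from `predict()`, which collapses each model's ROC curve to a single operating point. That is why the SVM scores exactly 0.500 here despite decent accuracy. Ranking models by AUC is more faithful when done on continuous scores from `predict_proba` (or `decision_function`). A sketch of the difference on synthetic data:

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=1000, weights=[0.84, 0.16], random_state=1)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, stratify=y, random_state=1)
clf = RandomForestClassifier(random_state=1).fit(X_tr, y_tr)

auc_labels = roc_auc_score(y_te, clf.predict(X_te))               # hard 0/1 labels
auc_scores = roc_auc_score(y_te, clf.predict_proba(X_te)[:, 1])   # continuous probabilities
print('AUC from labels:       ', auc_labels)
print('AUC from probabilities:', auc_scores)
```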
# Collect the recorded train/test ROC AUC scores into a DataFrame for plotting
scores_df = pd.DataFrame(scores_dict).T
scores_df.plot(kind='barh', figsize=(15, 8))
plt.title('Train vs Test ROC AUC by Model')
plt.xlabel('ROC AUC')
plt.ylabel('Model')
plt.show()